import logging
import os
from pathlib import Path
import glob
import numpy as np
import torch
import matplotlib.pyplot as plt

# Print options: widen output so tensors/arrays are not wrapped in logs.
#
# torch.set_printoptions(precision=None, threshold=None, edgeitems=None,
#                        linewidth=None, profile=None, sci_mode=None)
#   precision: digits of precision for floating-point output (torch default 4)
#   threshold: total element count above which output is summarized (default 1000)
#   edgeitems: items shown at the start and end of each dimension when summarized (default 3)
#   linewidth: characters per line used for line breaking (default 80)
#   profile:   preset that overrides the above; one of "default", "short", "full"
# The former profile='long' matched none of those presets and was silently ignored,
# so it is omitted here.
torch.set_printoptions(linewidth=320, precision=5)
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5


def set_logging():
    """Configure the root logger to print bare messages at INFO level and above.

    Delegates to ``logging.basicConfig``, which does nothing if the root logger
    already has handlers attached.
    """
    options = {"level": logging.INFO, "format": "%(message)s"}
    logging.basicConfig(**options)


def plot_labels(labels, save_dir=''):
    """Save a histogram of class frequencies to ``<save_dir>/labels.png``.

    Args:
        labels: sequence or array of integer class indices (any shape; flattened).
            The number of classes is taken as ``max(labels) + 1``.
        save_dir: directory the PNG is written into (it is not created here).

    Raises:
        ValueError: if ``labels`` is empty.
    """
    classes = np.asarray(labels).ravel().astype(int)
    if classes.size == 0:
        raise ValueError('plot_labels: labels is empty')
    nc = int(classes.max()) + 1  # number of classes
    fig, ax = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True)
    # One bin per class id, centred on the integer: edges at k - 0.5 .. k + 0.5.
    ax.hist(classes, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    ax.set_xlabel('classes')
    ax.set_ylabel("number")
    # Use the figure handle explicitly instead of pyplot's implicit "current figure".
    fig.savefig(Path(save_dir) / 'labels.png', dpi=200)
    plt.close(fig)


def increment_dir(dir, comment=''):
    """Return the next numbered run directory name, e.g. runs/exp1 -> runs/exp2_comment.

    Existing paths matching ``<dir>*`` are scanned. For each one, the text right after
    the ``dir`` prefix (up to an optional ``_comment`` suffix) is parsed as a run number.
    Paths whose suffix is not a number (e.g. ``runs/exp`` or ``runs/expfoo``) are
    ignored. Nothing is created on disk.

    Args:
        dir: path prefix of the run directories (e.g. ``runs/exp``).
        comment: optional suffix, appended as ``_<comment>``.

    Returns:
        ``dir + str(n)`` (plus ``'_' + comment`` when given), where ``n`` is one more
        than the highest existing run number, or 0 if there is none.
    """
    dir = str(Path(dir))  # os-agnostic
    numbers = []
    # escape: treat '[', '*', '?' in the prefix literally
    for path in glob.glob(glob.escape(dir) + '*'):
        # Parse only the part after the prefix; the prefix itself may contain '_'.
        head = path[len(dir):].partition('_')[0]
        if head.isdecimal():
            numbers.append(int(head))
    n = max(numbers) + 1 if numbers else 0  # increment
    return dir + str(n) + ('_' + comment if comment else '')


def labels_to_class_weights(labels, nc=80):
    """Return normalized inverse-frequency class weights computed from training labels.

    Args:
        labels: sequence or array of non-negative integer class indices (one per
            sample), or a sequence whose first element is ``None`` when no labels
            were loaded.
        nc: number of classes; the result has at least this many entries.

    Returns:
        A float64 ``torch.Tensor`` of weights summing to 1. An empty tensor is returned
        when there are no labels. Classes that never occur are counted once, so their
        weight stays finite instead of becoming ``inf`` (which would turn every
        normalized weight into ``nan``).
    """
    if len(labels) == 0 or labels[0] is None:  # no labels loaded
        return torch.Tensor()

    classes = np.asarray(labels).astype(int).ravel()
    weights = np.bincount(classes, minlength=nc).astype(np.float64)  # occurrences per class
    weights[weights == 0] = 1  # absent classes: avoid 1/0 -> inf
    weights = 1 / weights  # inverse frequency
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=None):
    """Return one sampling weight per image from per-class weights.

    Each image carries a single class label here, so an image's weight is simply its
    class's weight. This equals summing ``class_weights`` against a one-hot count vector.
    The result can be used e.g. as
    ``random.choices(range(len(labels)), weights=image_weights, k=1)``.

    Args:
        labels: sequence or array of integer class indices in ``[0, nc)``, one per image.
        nc: number of classes.
        class_weights: array of ``nc`` per-class weights; defaults to ``np.ones(nc)``.

    Returns:
        ``np.ndarray`` of shape ``(len(labels),)``.

    Raises:
        ValueError: if ``class_weights`` does not have ``nc`` entries, or a label is
            outside ``[0, nc)``.
    """
    if class_weights is None:  # built per call so the default always matches nc
        class_weights = np.ones(nc)
    class_weights = np.asarray(class_weights).reshape(nc)  # ValueError on size mismatch
    labels = np.asarray(labels, dtype=np.int64).reshape(-1)
    if labels.size and (labels.min() < 0 or labels.max() >= nc):
        raise ValueError('labels_to_image_weights: labels must be in [0, %d)' % nc)
    return class_weights[labels]


def plot_results(save_dir='', start=0, stop=0, labels=()):
    """Plot every ``results*.txt`` in *save_dir* into a 2x5 grid saved as ``results.png``.

    Columns 1-10 of each file are read with ``np.loadtxt`` and plotted, in order, under
    the titles in ``s`` below: train loss/accuracy/precision/recall/f1, then the same
    for validation. Column 0 is not plotted; presumably it holds the epoch (TODO confirm).
    Files that fail to load or plot are reported and skipped.

    Args:
        save_dir: directory holding the results files; ``results.png`` is written there.
        start: first row to plot.
        stop: stop row (exclusive); 0 plots through the last row.
        labels: optional legend labels, one per results file in sorted filename order.
            Each file's stem is used when this is empty.
    """
    fig, ax = plt.subplots(2, 5, figsize=(12, 6))
    ax = ax.ravel()
    s = ['loss', 'accuracy', "precision", "recall", 'f1-scores',
         'val loss', 'val accuracy', "val precision", "val recall", 'val f1-scores']
    # Path.glob treats save_dir literally; sorting makes the labels[fi] mapping deterministic.
    files = sorted(Path(save_dir).glob('results*.txt'))
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, usecols=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ndmin=2).T
            n = results.shape[1]  # number of rows in the file
            x = range(start, min(stop, n) if stop else n)
            label = labels[fi] if len(labels) else Path(f).stem
            for i, title in enumerate(s):
                ax[i].plot(x, results[i, x], marker='.', label=label, linewidth=1, markersize=6)
                ax[i].set_title(title)
        except Exception as e:  # best-effort: one bad file must not abort the whole plot
            print('Warning: Plotting error for %s; %s' % (f, e))

    fig.tight_layout()
    ax[1].legend()  # every subplot carries the same per-file labels; one legend suffices
    fig.savefig(Path(save_dir) / 'results.png', dpi=200)
    plt.close(fig)  # release the figure so repeated calls do not accumulate open figures

if __name__ == '__main__':
    import sys

    # Usage: python <this file> [save_dir]
    # Without an argument, falls back to the original developer-specific path.
    plot_results(sys.argv[1] if len(sys.argv) > 1 else
                 r"C:\Users\zengxh\Documents\workspace\PyCharm-workspace\PVDefectPlatform\abnormal\runs\23")